{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Hidden Markov Models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "author: Jacob Schreiber <br>\n",
    "contact: jmschreiber91@gmail.com"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Hidden Markov models (HMMs) are the flagship of the pomegranate package in that they have the most features of all of the models and that they were the first algorithm implemented.\n",
    "\n",
    "Hidden Markov models are a form of structured prediction method that are popular for tagging all elements in a sequence with some \"hidden\" state. They can be thought of as extensions of Markov chains where, instead of the probability of the next observation being dependant on the current observation, the probability of the next **hidden state** is dependant on the current hidden state, and the next observation is derived from that hidden state. An example of this can be part of speech tagging, where the observations are words and the hidden states are parts of speech. Each word gets tagged with a part of speech, but dynamic programming is utilized to search through all potential word-tag combinations to identify the best set of tags across the entire sentence.\n",
    "\n",
    "Another perspective of HMMs is that they are an extension on mixture models that includes a transition matrix. Conceptually, a mixture model has a set of \"hidden\" states---the mixture components---and one can calculate the probability that each sample belongs to each component. This approach treats each observations independently. However, like in the part-of-speech example we know that an adjective typically is followed by a noun, and so position in the sequence matters. A HMM adds a transition matrix between the hidden states to incorporate this information across the sequence, allowing for higher probabilities of transitioning from the \"adjective\" hidden state to a noun or verb.\n",
    "\n",
    "pomegranate implements HMMs in a flexible manner that goes beyond what other packages allow. Let's see some examples."
   ]
  },
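  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the 'mixture model plus transition matrix' view concrete, the cell below gives a minimal NumPy sketch. It is not pomegranate's implementation. It scores an observed sequence together with one specific hidden-state path: the start probability of the first state, one transition probability per step, and one emission probability per observation. The two-state parameters are assumptions that mirror the CG island example below, and `joint_log_probability` is a hypothetical helper name."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Hypothetical two-state model: 0 = background, 1 = CG island (assumed values).\n",
    "symbols = 'ACGT'\n",
    "hmm_start = np.array([0.5, 0.5])\n",
    "hmm_trans = np.array([[0.9, 0.1],\n",
    "                      [0.1, 0.9]])\n",
    "hmm_emit = np.array([[0.25, 0.25, 0.25, 0.25],\n",
    "                     [0.10, 0.40, 0.40, 0.10]])\n",
    "\n",
    "\n",
    "def joint_log_probability(observed, hidden_path, start, trans, emit):\n",
    "    # log P(observations, hidden path) for a first-order HMM with no end state.\n",
    "    obs = [symbols.index(c) for c in observed]\n",
    "    logp = np.log(start[hidden_path[0]]) + np.log(emit[hidden_path[0], obs[0]])\n",
    "    for t in range(1, len(obs)):\n",
    "        # One transition term per step: the part a mixture model does not have...\n",
    "        logp += np.log(trans[hidden_path[t - 1], hidden_path[t]])\n",
    "        # ...plus one emission term per observation, as in a mixture model.\n",
    "        logp += np.log(emit[hidden_path[t], obs[t]])\n",
    "    return logp\n",
    "\n",
    "\n",
    "# Staying in the CG island state scores a CG-rich string better than alternating.\n",
    "stay = joint_log_probability('CGCG', [1, 1, 1, 1], hmm_start, hmm_trans, hmm_emit)\n",
    "switch = joint_log_probability('CGCG', [1, 0, 1, 0], hmm_start, hmm_trans, hmm_emit)\n",
    "print(stay, switch)"
   ]
  },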
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fri Jan 10 2020 \n",
      "\n",
      "numpy 1.18.1\n",
      "scipy 1.4.1\n",
      "pomegranate 0.12.0\n",
      "\n",
      "compiler   : Clang 10.0.0 (clang-1000.11.45.5)\n",
      "system     : Darwin\n",
      "release    : 17.7.0\n",
      "machine    : x86_64\n",
      "processor  : i386\n",
      "CPU cores  : 4\n",
      "interpreter: 64bit\n"
     ]
    }
   ],
   "source": [
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn; seaborn.set_style('whitegrid')\n",
    "import numpy\n",
    "\n",
    "from pomegranate import *\n",
    "\n",
    "numpy.random.seed(0)\n",
    "numpy.set_printoptions(suppress=True)\n",
    "\n",
    "%load_ext watermark\n",
    "%watermark -m -n -p numpy,scipy,pomegranate"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### CG rich region identification example\n",
    "\n",
    "Lets take the simplified example of CG island detection on a sequence of DNA. DNA is made up of the four canonical nucleotides, abbreviated 'A', 'C', 'G', and 'T'. We can say that regions of the genome that are enriched for nucleotides 'C' and 'G' are 'CG islands', which is a simplification of the real biological concept but sufficient for our example. The issue with identifying these regions is that they are not exclusively made up of the nucleotides 'C' and 'G', but have some 'A's and 'T's scatted amongst them. A simple model that looked for long stretches of C's and G's would not perform well, because it would miss most of the real regions.\n",
    "\n",
    "We can start off by building the model. Because HMMs involve the transition matrix, which is often represented using a graph over the hidden states, building them requires a few more steps that a simple distribution or the mixture model. Our simple model will be composed of two distributions. One distribution wil be a uniform distribution across all four characters and one will have a preference for the nucleotides C and G, while still allowing the nucleotides A and T to be present."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "d1 = DiscreteDistribution({'A': 0.25, 'C': 0.25, 'G': 0.25, 'T': 0.25})\n",
    "d2 = DiscreteDistribution({'A': 0.10, 'C': 0.40, 'G': 0.40, 'T': 0.10})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For the HMM we have to first define states, which are a pair of a distribution and a name."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "s1 = State(d1, name='background')\n",
    "s2 = State(d2, name='CG island')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we define the HMM and pass in the states."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = HiddenMarkovModel()\n",
    "model.add_states(s1, s2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then we have to define the transition matrix, which is the probability of going from one hidden state to the next hidden state. In some cases, like this one, there are high self-loop probabilities, indicating that it's likely that  one will stay in the same hidden state from one observation to the next in the sequence. Other cases have a lower probability of staying in the same state, like the part of speech tagger. A part of the transition matrix is the start probabilities, which is the probability of starting in each of the hidden states. Because we create these transitions one at a time, they are very amenable to sparse transition matrices, where it is impossible to transition from one hidden state to the next."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.add_transition(model.start, s1, 0.5)\n",
    "model.add_transition(model.start, s2, 0.5)\n",
    "model.add_transition(s1, s1, 0.9)\n",
    "model.add_transition(s1, s2, 0.1)\n",
    "model.add_transition(s2, s1, 0.1)\n",
    "model.add_transition(s2, s2, 0.9)"
   ]
  },
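  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each `add_transition` call adds a single edge, so any pair of states without an edge simply has probability zero. As a rough illustration (plain NumPy, not pomegranate's internal representation), the six edges above can be written as hypothetical `(from, to, probability)` tuples and expanded into a dense matrix. In that matrix, each row belonging to a state with outgoing edges should sum to 1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# The six edges added above, written as (from_state, to_state, probability).\n",
    "cg_edges = [('start', 'background', 0.5), ('start', 'CG island', 0.5),\n",
    "            ('background', 'background', 0.9), ('background', 'CG island', 0.1),\n",
    "            ('CG island', 'background', 0.1), ('CG island', 'CG island', 0.9)]\n",
    "\n",
    "cg_names = ['start', 'background', 'CG island']\n",
    "cg_index = {name: i for i, name in enumerate(cg_names)}\n",
    "\n",
    "# Pairs without an edge stay at zero, which is what makes the structure sparse.\n",
    "cg_dense = np.zeros((len(cg_names), len(cg_names)))\n",
    "for src, dst, prob in cg_edges:\n",
    "    cg_dense[cg_index[src], cg_index[dst]] = prob\n",
    "\n",
    "print(cg_dense)\n",
    "# Each state with outgoing edges should define a distribution over its successors.\n",
    "print(cg_dense.sum(axis=1))"
   ]
  },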
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, finally, we need to bake the model in order to finalize the internal structure. Bake must be called when the model has been fully specified."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.bake()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can make predictions on some sequence. Let's create some sequence that has a CG enriched region in the middle and see whether we can identify it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sequence: CGACTACTGACTACTCGCCGACGCGACTGCCGTCTATACTGCGCATACGGC\n",
      "hmm pred: 111111111111111000000000000000011111111111111110000\n"
     ]
    }
   ],
   "source": [
    "seq = numpy.array(list('CGACTACTGACTACTCGCCGACGCGACTGCCGTCTATACTGCGCATACGGC'))\n",
    "\n",
    "hmm_predictions = model.predict(seq)\n",
    "\n",
    "print(\"sequence: {}\".format(''.join(seq)))\n",
    "print(\"hmm pred: {}\".format(''.join(map( str, hmm_predictions))))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It looks like it successfully identified a CG island in the middle (the long stretch of 0's) and another shorter one at the end. The predicted integers don't correspond to the order in which states were added to the model, but rather, the order that they exist in the model after a topological sort. More importantly, the model wasn't tricked into thinking that every CG or even pair of CGs was an island. It required many C's and G's to be part of a longer stretch to identify that region as an island. Naturally, the balance of the transition and emission probabilities will heavily influence what regions are detected.\n",
    "\n",
    "Let's say, though, that we want to get rid of that CG island prediction at the end because we don't believe that real islands can occur at the end of the sequence. We can take care of this by adding in an explicit end state that only the non-island hidden state can get to. We enforce that the model has to end in the end state, and if only the non-island state gets there, the sequence of hidden states must end in the non-island state. Here's how:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = HiddenMarkovModel()\n",
    "model.add_states(s1, s2)\n",
    "model.add_transition(model.start, s1, 0.5)\n",
    "model.add_transition(model.start, s2, 0.5)\n",
    "model.add_transition(s1, s1, 0.89 )\n",
    "model.add_transition(s1, s2, 0.10 )\n",
    "model.add_transition(s1, model.end, 0.01)\n",
    "model.add_transition(s2, s1, 0.1 )\n",
    "model.add_transition(s2, s2, 0.9)\n",
    "model.bake()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that all we did was add a transition from `s1` to `model.end` with some low probability. This probability doesn't have to be high if there's only a single transition there, because there's no other possible way of getting to the end state."
   ]
  },
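  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The effect of the end state can be illustrated with a small log-space Viterbi decoder written in NumPy. This sketch assumes that the end transition just adds each state's log probability of moving to `end` to the final scores: log 0.01 for background and log 0 = -inf for the island. As a result, no path that finishes in the island state can win. Viterbi is a swapped-in choice for simplicity; the labels in this notebook come from `predict`, which reports posterior-based labels. `viterbi_sketch` and the test string are hypothetical, and the sketch's state order (background = 0, CG island = 1) differs from pomegranate's. With the unconstrained transitions, the decoded path for this string ends in state 1. With the end transition, its final label is forced back to 0."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def viterbi_sketch(observed, start, trans, emit, end=None, symbols='ACGT'):\n",
    "    # Log-space Viterbi decoding; returns the most likely hidden-state path.\n",
    "    obs = [symbols.index(c) for c in observed]\n",
    "    with np.errstate(divide='ignore'):  # log(0) = -inf marks impossible moves\n",
    "        log_start, log_trans, log_emit = np.log(start), np.log(trans), np.log(emit)\n",
    "        log_end = np.zeros(len(start)) if end is None else np.log(end)\n",
    "    n, k = len(obs), len(start)\n",
    "    score = np.zeros((n, k))\n",
    "    back = np.zeros((n, k), dtype=int)\n",
    "    score[0] = log_start + log_emit[:, obs[0]]\n",
    "    for t in range(1, n):\n",
    "        # cand[i, j]: best score ending in state i at t - 1, then moving to state j.\n",
    "        cand = score[t - 1][:, None] + log_trans\n",
    "        back[t] = cand.argmax(axis=0)\n",
    "        score[t] = cand.max(axis=0) + log_emit[:, obs[t]]\n",
    "    # The end transition (if any) only changes the choice of the final state.\n",
    "    path = [int((score[-1] + log_end).argmax())]\n",
    "    for t in range(n - 1, 0, -1):\n",
    "        path.append(int(back[t, path[-1]]))\n",
    "    return path[::-1]\n",
    "\n",
    "\n",
    "vit_start = np.array([0.5, 0.5])\n",
    "vit_emit = np.array([[0.25, 0.25, 0.25, 0.25], [0.10, 0.40, 0.40, 0.10]])\n",
    "vit_trans_free = np.array([[0.9, 0.1], [0.1, 0.9]])\n",
    "# As in the cell above, background gives 0.01 of its self-loop to the end state.\n",
    "vit_trans_end = np.array([[0.89, 0.10], [0.1, 0.9]])\n",
    "vit_end = np.array([0.01, 0.0])\n",
    "\n",
    "tail = 'ATATATCGCGCGCGCG'  # made-up sequence with a CG-rich tail\n",
    "print(viterbi_sketch(tail, vit_start, vit_trans_free, vit_emit))\n",
    "print(viterbi_sketch(tail, vit_start, vit_trans_end, vit_emit, end=vit_end))"
   ]
  },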
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sequence: CGACTACTGACTACTCGCCGACGCGACTGCCGTCTATACTGCGCATACGGC\n",
      "hmm pred: 111111111111111000000000000000011111111111111111111\n"
     ]
    }
   ],
   "source": [
    "seq = numpy.array(list('CGACTACTGACTACTCGCCGACGCGACTGCCGTCTATACTGCGCATACGGC'))\n",
    "\n",
    "hmm_predictions = model.predict(seq)\n",
    "\n",
    "print(\"sequence: {}\".format(''.join(seq)))\n",
    "print(\"hmm pred: {}\".format(''.join(map( str, hmm_predictions))))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This seems far more reasonable. There is a single CG island surrounded by background sequence, and something at the end. If we knew that CG islands cannot occur at the end of sequences, we need only modify the underlying structure of the HMM in order to say that the sequence must end from the background state.\n",
    "\n",
    "In the same way that mixtures could provide probabilistic estimates of class assignments rather than only hard labels, hidden Markov models can do the same. These estimates are the posterior probabilities of belonging to each of the hidden states given the observation, but also given the rest of the sequence."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[0.19841088 0.80158912]\n",
      " [0.32919701 0.67080299]\n",
      " [0.38366073 0.61633927]\n",
      " [0.58044619 0.41955381]\n",
      " [0.69075524 0.30924476]\n",
      " [0.74653183 0.25346817]\n",
      " [0.76392808 0.23607192]]\n"
     ]
    }
   ],
   "source": [
    "print(model.predict_proba(seq)[12:19])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see here the transition from the first non-island region to the middle island region, with high probabilities in one column turning into high probabilities in the other column. The `predict` method is just taking the most likely element, the maximum-a-posteriori estimate."
   ]
  },
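  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The posterior matrix above comes from the forward-backward algorithm. The cell below is a minimal scaled forward-backward sketch in NumPy for a two-state model without an end state. That is an assumption: the model above does have an end state, so its numbers are not reproduced here. A row-wise argmax at the end mimics MAP labelling. `posterior_sketch` and the example string are hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def posterior_sketch(observed, start, trans, emit, symbols='ACGT'):\n",
    "    # Scaled forward-backward: P(state at position t | whole sequence) for every t.\n",
    "    obs = [symbols.index(c) for c in observed]\n",
    "    n, k = len(obs), len(start)\n",
    "    alpha, beta, scale = np.zeros((n, k)), np.ones((n, k)), np.zeros(n)\n",
    "    alpha[0] = start * emit[:, obs[0]]\n",
    "    scale[0] = alpha[0].sum()\n",
    "    alpha[0] /= scale[0]\n",
    "    for t in range(1, n):  # forward pass, renormalized at every step\n",
    "        alpha[t] = (alpha[t - 1] @ trans) * emit[:, obs[t]]\n",
    "        scale[t] = alpha[t].sum()\n",
    "        alpha[t] /= scale[t]\n",
    "    for t in range(n - 2, -1, -1):  # backward pass, reusing the same scale factors\n",
    "        beta[t] = trans @ (emit[:, obs[t + 1]] * beta[t + 1]) / scale[t + 1]\n",
    "    return alpha * beta  # with this scaling, each row already sums to 1\n",
    "\n",
    "\n",
    "fb_start = np.array([0.5, 0.5])\n",
    "fb_trans = np.array([[0.9, 0.1], [0.1, 0.9]])\n",
    "fb_emit = np.array([[0.25, 0.25, 0.25, 0.25], [0.10, 0.40, 0.40, 0.10]])\n",
    "fb_seq = 'ATATCGCGCGCGATAT'  # made-up example sequence\n",
    "\n",
    "fb_post = posterior_sketch(fb_seq, fb_start, fb_trans, fb_emit)\n",
    "print(fb_post.round(3))\n",
    "# MAP label per position (0 = background, 1 = CG island in this sketch).\n",
    "print(fb_post.argmax(axis=1))"
   ]
  },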
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In addition to using the forward-backward algorithm to just calculate posterior probabilities for each observation, we can count the number of transitions that are predicted to occur between the hidden states."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[15.78100555  2.89559314  0.          0.        ]\n",
      " [ 2.41288774 28.91051356  0.          1.        ]\n",
      " [ 0.4827054   0.5172946   0.          0.        ]\n",
      " [ 0.          0.          0.          0.        ]]\n"
     ]
    }
   ],
   "source": [
    "trans, ems = model.forward_backward(seq)\n",
    "print(trans)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is the transition table, which has the soft count of the number of transitions across an edge in the model given a single sequence. It is a square matrix of size equal to the number of states (including start and end state), with number of transitions from (row_id) to (column_id). This is exemplified by the 1.0 in the first row, indicating that there is one transition from background state to the end state, as that's the only way to reach the end state. However, the third (or fourth, depending on ordering) row is the transitions from the start state, and it only slightly favors the background state. These counts are not normalized to the length of the input sequence, but can easily be done so by dividing by row sums, column sums, or entire table sums, depending on your application.\n",
    "\n",
    "A possible reason not to normalize is to run several sequences through and add up their tables, because normalizing in the end and extracting some domain knowledge. It is extremely useful in practice. For example, we can see that there is an expectation of ~2.9 transitions from CG island to background, and ~2.4 from background to CG island. This could be used to infer that there are ~2-3 edges, which makes sense if you consider that the start and end of the sequence seem like they might be part of the CG island states except for the strict transition probabilities used (look at the first few rows of the emission table above.)"
   ]
  },
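  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The soft counts can be sketched from the same forward and backward quantities. For each pair of adjacent positions t and t+1, the expected count on edge i -> j is proportional to `alpha_t(i) * a_ij * e_j(x_{t+1}) * beta_{t+1}(j)`. This NumPy sketch assumes a model with no explicit start or end state, so it only produces the block of the table between the two emitting states. Each adjacent pair contributes a total of 1, so the counts sum to `len(sequence) - 1`, and dividing by row sums gives a transition-matrix-like estimate. The function name and the example string are hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def expected_transitions_sketch(observed, start, trans, emit, symbols='ACGT'):\n",
    "    # Expected number of i -> j transitions given the whole sequence (soft counts).\n",
    "    obs = [symbols.index(c) for c in observed]\n",
    "    n, k = len(obs), len(start)\n",
    "    alpha, beta, scale = np.zeros((n, k)), np.ones((n, k)), np.zeros(n)\n",
    "    alpha[0] = start * emit[:, obs[0]]\n",
    "    scale[0] = alpha[0].sum()\n",
    "    alpha[0] /= scale[0]\n",
    "    for t in range(1, n):  # scaled forward pass\n",
    "        alpha[t] = (alpha[t - 1] @ trans) * emit[:, obs[t]]\n",
    "        scale[t] = alpha[t].sum()\n",
    "        alpha[t] /= scale[t]\n",
    "    for t in range(n - 2, -1, -1):  # scaled backward pass\n",
    "        beta[t] = trans @ (emit[:, obs[t + 1]] * beta[t + 1]) / scale[t + 1]\n",
    "    counts = np.zeros((k, k))\n",
    "    for t in range(n - 1):\n",
    "        # xi_t[i, j] = alpha_t(i) * a_ij * e_j(x_{t+1}) * beta_{t+1}(j) / c_{t+1}\n",
    "        counts += np.outer(alpha[t], emit[:, obs[t + 1]] * beta[t + 1]) * trans / scale[t + 1]\n",
    "    return counts\n",
    "\n",
    "\n",
    "et_start = np.array([0.5, 0.5])\n",
    "et_trans = np.array([[0.9, 0.1], [0.1, 0.9]])\n",
    "et_emit = np.array([[0.25, 0.25, 0.25, 0.25], [0.10, 0.40, 0.40, 0.10]])\n",
    "et_seq = 'ATATCGCGCGCGCGATAT'  # made-up example sequence\n",
    "\n",
    "et_counts = expected_transitions_sketch(et_seq, et_start, et_trans, et_emit)\n",
    "print(et_counts)\n",
    "print(et_counts.sum())  # len(et_seq) - 1, up to rounding\n",
    "# Row-normalized counts: an estimate of the transition probabilities.\n",
    "print(et_counts / et_counts.sum(axis=1, keepdims=True))"
   ]
  },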
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Sequence Alignment Example"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Lets move on to a more complicated structure, that of a profile HMM. A profile HMM is used to align a sequence to a reference 'profile', where the reference profile can either be a single sequence, or an alignment of many sequences (such as a reference genome). In essence, this profile has a 'match' state for every position in the reference profile, and 'insert' state, and a 'delete' state. The insert state allows the external sequence to have an insertion into the sequence without throwing off the entire alignment, such as the following:\n",
    "\n",
    "`ACCG : Sequence` <br>\n",
    "`|| |` <br>\n",
    "`AC-G : Reference`\n",
    "\n",
    "or a deletion, which is the opposite:\n",
    "\n",
    "`A-G : Sequence` <br>\n",
    "`| |` <br>\n",
    "`ACG : Reference`\n",
    "\n",
    "The bars in the middle refer to a perfect match, whereas the lack of a bar means either a deletion/insertion, or a mismatch. A mismatch is where two positions are aligned together, but do not match. This models the biological phenomena of mutation, where one nucleotide can convert to another over time. It is usually more likely in biological sequences that this type of mutation occurs than that the nucleotide was deleted from the sequence (shifting all nucleotides over by one) and then another was inserted at the exact location (moving all nucleotides over again). Since we are using a probabilistic model, we get to define these probabilities through the use of distributions! If we want to model mismatches, we can just set our 'match' state to have an appropriate distribution with non-zero probabilities over mismatches. \n",
    "\n",
    "Lets now create a three nucleotide profile HMM, which models the sequence 'ACT'. We will fuzz this a little bit in the match states, pretending to have some prior information about what mutations occur at each position. If you don't have any information, setting a uniform, small, value over the other values is usually okay."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = HiddenMarkovModel( \"Global Alignment\")\n",
    "\n",
    "# Define the distribution for insertions\n",
    "i_d = DiscreteDistribution( { 'A': 0.25, 'C': 0.25, 'G': 0.25, 'T': 0.25 } )\n",
    "\n",
    "# Create the insert states\n",
    "i0 = State( i_d, name=\"I0\" )\n",
    "i1 = State( i_d, name=\"I1\" )\n",
    "i2 = State( i_d, name=\"I2\" )\n",
    "i3 = State( i_d, name=\"I3\" )\n",
    "\n",
    "# Create the match states\n",
    "m1 = State( DiscreteDistribution({ \"A\": 0.95, 'C': 0.01, 'G': 0.01, 'T': 0.02 }) , name=\"M1\" )\n",
    "m2 = State( DiscreteDistribution({ \"A\": 0.003, 'C': 0.99, 'G': 0.003, 'T': 0.004 }) , name=\"M2\" )\n",
    "m3 = State( DiscreteDistribution({ \"A\": 0.01, 'C': 0.01, 'G': 0.01, 'T': 0.97 }) , name=\"M3\" )\n",
    "\n",
    "# Create the delete states\n",
    "d1 = State( None, name=\"D1\" )\n",
    "d2 = State( None, name=\"D2\" )\n",
    "d3 = State( None, name=\"D3\" )\n",
    "\n",
    "# Add all the states to the model\n",
    "model.add_states( [i0, i1, i2, i3, m1, m2, m3, d1, d2, d3 ] )\n",
    "\n",
    "# Create transitions from match states\n",
    "model.add_transition( model.start, m1, 0.9 )\n",
    "model.add_transition( model.start, i0, 0.1 )\n",
    "model.add_transition( m1, m2, 0.9 )\n",
    "model.add_transition( m1, i1, 0.05 )\n",
    "model.add_transition( m1, d2, 0.05 )\n",
    "model.add_transition( m2, m3, 0.9 )\n",
    "model.add_transition( m2, i2, 0.05 )\n",
    "model.add_transition( m2, d3, 0.05 )\n",
    "model.add_transition( m3, model.end, 0.9 )\n",
    "model.add_transition( m3, i3, 0.1 )\n",
    "\n",
    "# Create transitions from insert states\n",
    "model.add_transition( i0, i0, 0.70 )\n",
    "model.add_transition( i0, d1, 0.15 )\n",
    "model.add_transition( i0, m1, 0.15 )\n",
    "\n",
    "model.add_transition( i1, i1, 0.70 )\n",
    "model.add_transition( i1, d2, 0.15 )\n",
    "model.add_transition( i1, m2, 0.15 )\n",
    "\n",
    "model.add_transition( i2, i2, 0.70 )\n",
    "model.add_transition( i2, d3, 0.15 )\n",
    "model.add_transition( i2, m3, 0.15 )\n",
    "\n",
    "model.add_transition( i3, i3, 0.85 )\n",
    "model.add_transition( i3, model.end, 0.15 )\n",
    "\n",
    "# Create transitions from delete states\n",
    "model.add_transition( d1, d2, 0.15 )\n",
    "model.add_transition( d1, i1, 0.15 )\n",
    "model.add_transition( d1, m2, 0.70 ) \n",
    "\n",
    "model.add_transition( d2, d3, 0.15 )\n",
    "model.add_transition( d2, i2, 0.15 )\n",
    "model.add_transition( d2, m3, 0.70 )\n",
    "\n",
    "model.add_transition( d3, i3, 0.30 )\n",
    "model.add_transition( d3, model.end, 0.70 )\n",
    "\n",
    "# Call bake to finalize the structure of the model.\n",
    "model.bake()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now lets try to align some sequences to it and see what happens!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sequence: 'ACT'  -- Log Probability: -0.5132449003570658 -- Path: M1 M2 M3\n",
      "Sequence: 'GGC'  -- Log Probability: -11.048101241343396 -- Path: I0 I0 D1 M2 D3\n",
      "Sequence: 'GAT'  -- Log Probability: -9.125519674022627 -- Path: I0 M1 D2 M3\n",
      "Sequence: 'ACC'  -- Log Probability: -5.0879558788604475 -- Path: M1 M2 M3\n"
     ]
    }
   ],
   "source": [
    "for sequence in map( list, ('ACT', 'GGC', 'GAT', 'ACC') ):\n",
    "    logp, path = model.viterbi( sequence )\n",
    "    print(\"Sequence: '{}'  -- Log Probability: {} -- Path: {}\".format(\n",
    "        ''.join( sequence ), logp, \" \".join( state.name for idx, state in path[1:-1] ) ))"
   ]
  },
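  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sanity check, the log probability printed for 'ACT' can be reproduced by hand from the parameters above. The path start -> M1 -> M2 -> M3 -> end takes four transitions of probability 0.9, and the match states emit A, C and T with probabilities 0.95, 0.99 and 0.97. The cell below only sums those log terms with the standard library; it does not call pomegranate."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "# Transitions: start->M1, M1->M2, M2->M3, M3->end (all 0.9 in the model above).\n",
    "act_transitions = [0.9, 0.9, 0.9, 0.9]\n",
    "# Emissions: M1 emits 'A', M2 emits 'C', M3 emits 'T'.\n",
    "act_emissions = [0.95, 0.99, 0.97]\n",
    "\n",
    "act_logp = sum(math.log(p) for p in act_transitions + act_emissions)\n",
    "print(act_logp)  # agrees with the Viterbi result for 'ACT' up to float rounding"
   ]
  },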
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The first and last sequence are entirely matches, meaning that it thinks the most likely alignment between the profile ACT and ACT is A-A, C-C, and T-T, which makes sense, and the most likely alignment between ACT and ACC is A-A, C-C, and T-C, which includes a mismatch. Essentially, it's more likely that there's a T-C mismatch at the end then that there was a deletion of a T at the end of the sequence, and a separate insertion of a C.\n",
    "\n",
    "The two middle sequences don't match very well, as expected! G's are not very likely in this profile at all. It predicts that the two G's are inserts, and that the C matches the C in the profile, before hitting the delete state because it can't emit a T. The third sequence thinks that the G is an insert, as expected, and then aligns the A and T in the sequence to the A and T in the master sequence, missing the middle C in the profile.\n",
    "\n",
    "By using deletes, we can handle other sequences which are shorter than three characters. Lets look at some more sequences of different lengths."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sequence: 'A'  -- Log Probability: -5.406181012423981 -- Path: M1 D2 D3\n",
      "Sequence: 'GA'  -- Log Probability: -10.88681993576597 -- Path: I0 M1 D2 D3\n",
      "Sequence: 'AC'  -- Log Probability: -3.6244718790494277 -- Path: M1 M2 D3\n",
      "Sequence: 'AT'  -- Log Probability: -3.644880750680635 -- Path: M1 D2 M3\n",
      "Sequence: 'ATCC'  -- Log Probability: -10.674332964640293 -- Path: M1 D2 M3 I3 I3\n",
      "Sequence: 'ACGTG'  -- Log Probability: -10.393824835172445 -- Path: M1 M2 I2 I2 I2 D3\n",
      "Sequence: 'ATTT'  -- Log Probability: -8.67126440174503 -- Path: M1 I1 I1 D2 M3\n",
      "Sequence: 'TACCCTC'  -- Log Probability: -16.903451796110275 -- Path: I0 I0 I0 I0 D1 M2 M3 I3\n",
      "Sequence: 'TGTCAACACT'  -- Log Probability: -16.451699654050792 -- Path: I0 I0 I0 I0 I0 I0 I0 M1 M2 M3\n"
     ]
    }
   ],
   "source": [
    "for sequence in map( list, ('A', 'GA', 'AC', 'AT', 'ATCC', 'ACGTG', 'ATTT', 'TACCCTC', 'TGTCAACACT') ):\n",
    "    logp, path = model.viterbi( sequence )\n",
    "    print(\"Sequence: '{}'  -- Log Probability: {} -- Path: {}\".format(\n",
    "        ''.join( sequence ), logp, \" \".join( state.name for idx, state in path[1:-1] ) ))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Again, more of the same expected. You'll notice most of the use of insertion states are at I0, because most of the insertions are at the beginning of the sequence. It's more probable to simply stay in I0 at the beginning instead of go from I0 to D1 to I1, or going to another insert state along there. You'll see other insert states used when insertions occur in other places in the sequence, like 'ATTT' and 'ACGTG'.\n",
    "Now that we have the path, we need to convert it into an alignment, which is significantly more informative to look at."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sequence: A, Log Probability: -5.406181012423981\n",
      "ACT\n",
      "A--\n",
      "\n",
      "Sequence: GA, Log Probability: -10.88681993576597\n",
      "-ACT\n",
      "GA--\n",
      "\n",
      "Sequence: AC, Log Probability: -3.6244718790494277\n",
      "ACT\n",
      "AC-\n",
      "\n",
      "Sequence: AT, Log Probability: -3.644880750680635\n",
      "ACT\n",
      "A-T\n",
      "\n",
      "Sequence: ATCC, Log Probability: -10.674332964640293\n",
      "ACT--\n",
      "A-TCC\n",
      "\n",
      "Sequence: ACGTG, Log Probability: -10.393824835172445\n",
      "AC---T\n",
      "ACGTG-\n",
      "\n",
      "Sequence: ATTT, Log Probability: -8.67126440174503\n",
      "A--CT\n",
      "ATT-T\n",
      "\n",
      "Sequence: TACCCTC, Log Probability: -16.903451796110275\n",
      "----ACT-\n",
      "TACC-CTC\n",
      "\n",
      "Sequence: TGTCAACACT, Log Probability: -16.451699654050792\n",
      "-------ACT\n",
      "TGTCAACACT\n",
      "\n"
     ]
    }
   ],
   "source": [
    "def path_to_alignment( x, y, path ):\n",
    "    \"\"\"\n",
    "    This function will take in two sequences, and the ML path which is their alignment,\n",
    "    and insert dashes appropriately to make them appear aligned. This consists only of\n",
    "    adding a dash to the model sequence for every insert in the path appropriately, and\n",
    "    a dash in the observed sequence for every delete in the path appropriately.\n",
    "    \"\"\"\n",
    "    \n",
    "    for i, (index, state) in enumerate( path[1:-1] ):\n",
    "        name = state.name\n",
    "        \n",
    "        if name.startswith( 'D' ):\n",
    "            y = y[:i] + '-' + y[i:]\n",
    "        elif name.startswith( 'I' ):\n",
    "            x = x[:i] + '-' + x[i:]\n",
    "\n",
    "    return x, y\n",
    "\n",
    "for sequence in map( list, ('A', 'GA', 'AC', 'AT', 'ATCC', 'ACGTG', 'ATTT', 'TACCCTC', 'TGTCAACACT') ):\n",
    "    logp, path = model.viterbi( sequence )\n",
    "    x, y = path_to_alignment( 'ACT', ''.join(sequence), path )\n",
    "    \n",
    "    print(\"Sequence: {}, Log Probability: {}\".format( ''.join(sequence), logp ))\n",
    "    print(\"{}\\n{}\".format( x, y ))\n",
    "    print()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training Hidden Markov Models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "There are two main algorithms for training hidden Markov models-- Baum Welch (structured version of Expectation Maximization), and Viterbi training. Since we don't start off with labels on the data, these are both unsupervised training algorithms. In order to assign labels, Baum Welch uses EM to assign soft labels (weights in this case) to each point belonging to each state, and then using weighted MLE estimates to update the distributions. Viterbi assigns hard labels to each observation using the Viterbi algorithm, and then updates the distributions based on these hard labels.\n",
    "\n",
    "pomegranate is extremely well featured when it comes to regularization methods for training, supporting tied emissions and edges, edge and emission inertia, freezing nodes or edges, edge pseudocounts, and multithreaded training. Lets look at some examples of the following:\n",
    "\n",
    "### Tied Emissions\n",
    "\n",
    "Sometimes we want to say that multiple states model the same phenomena, but are simply at different points in the graph because we are utilizing complicated edge structure. An example is in the example of the global alignment HMM we saw. All insert states represent the same phenomena, which is nature randomly inserting a nucleotide, and this probability should be the same regardless of position. However, we can't simply have a single insert state, or we'd be allowed to transition from any match state to any other match state.\n",
    "\n",
    "You can tie emissions together simply by passing the same distribution object to multiple states. That's it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "d = NormalDistribution( 5, 2 )\n",
    "\n",
    "s1 = State( d, name=\"Tied1\" )\n",
    "s2 = State( d, name=\"Tied2\" )\n",
    "\n",
    "s3 = State( NormalDistribution( 5, 2 ), name=\"NotTied1\" )\n",
    "s4 = State( NormalDistribution( 5, 2 ), name=\"NotTied2\" )"
   ]
  },
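  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Conceptually, a tied distribution is re-estimated once from the pooled, weighted data of every state that shares it. The sketch below assumes made-up posterior weights of six observations for two tied states. It computes a single weighted Gaussian maximum-likelihood estimate from the pooled data, next to the separate estimates that two untied states would get. It illustrates the idea only and is not pomegranate's internal code; `weighted_normal_mle` is a hypothetical helper. The two made-up weight vectors sum to 1 for every observation, so the tied estimate equals the plain mean and standard deviation of `tie_x`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def weighted_normal_mle(x, w):\n",
    "    # Weighted maximum-likelihood mean and standard deviation of a Gaussian.\n",
    "    mean = np.sum(w * x) / np.sum(w)\n",
    "    var = np.sum(w * (x - mean) ** 2) / np.sum(w)\n",
    "    return mean, np.sqrt(var)\n",
    "\n",
    "\n",
    "tie_x = np.array([5.0, 2.0, 3.0, 4.0, 5.0, 7.0])\n",
    "# Hypothetical posterior weights of each observation for the two tied states.\n",
    "tie_w1 = np.array([0.9, 0.2, 0.5, 0.6, 0.1, 0.8])\n",
    "tie_w2 = np.array([0.1, 0.8, 0.5, 0.4, 0.9, 0.2])\n",
    "\n",
    "# Tied: a single update from the concatenated data and weights of both states.\n",
    "tied_params = weighted_normal_mle(np.concatenate([tie_x, tie_x]),\n",
    "                                  np.concatenate([tie_w1, tie_w2]))\n",
    "# Untied: each state would get its own estimate instead.\n",
    "untied_params = (weighted_normal_mle(tie_x, tie_w1), weighted_normal_mle(tie_x, tie_w2))\n",
    "print(tied_params, untied_params)"
   ]
  },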
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You have now indicated that these two states are tied, and when training, the weights of all points going to s2 will be added to the weights of all points going to s1 when updating d. As a side note, this is implemented in a computationally efficient manner such that d will only be updated once, not twice (but giving the same result). s3 and s4 are not tied together, because while they have the same distribution, it is not the same python object.\n",
    "\n",
    "### Tied Edges\n",
    "\n",
    "Edges can be tied together for the same reason. If you have a modular structure to your HMM, perhaps you believe this repeating structure doesn't (or shouldn't) have a position specific edge structure. You can do this simply by adding a group when you add transitions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = HiddenMarkovModel()\n",
    "model.add_states( [s1, s2] )\n",
    "model.add_transition( model.start, s1, 0.5, group='a' )\n",
    "model.add_transition( model.start, s2, 0.5, group='b' )\n",
    "model.add_transition( s1, s2, 0.5, group='a' )\n",
    "model.add_transition( s2, s1, 0.5, group='b' )\n",
    "model.bake()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The above model doesn't necessarily make sense, but it shows how simple it is to tie edges as well. You can go ahead and train normally from this point, without needing to change any code.\n",
    "\n",
    "### Inertia\n",
    "\n",
    "The next options are inertia on edges or on distributions. This simply means that you update your parameters as (previous_parameter * inertia) + (new_parameter * (1-inertia) ). It is a way to prevent your updates from overfitting immediately. You can specify this in the train function using either `edge_inertia` or `distribution_inertia`. These default to 0, with 1 being the maximum, meaning that you don't update based on new evidence, the same as freezing a distribution or the edges."
   ]
  },
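  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The inertia rule is a convex combination of the old parameter and the newly estimated one. The sketch below uses a hypothetical `inertia_update` helper (not a pomegranate function) and applies it repeatedly with a fixed new estimate. The parameter moves toward the estimate geometrically, shrinking the remaining gap by a factor of `inertia` per step. Starting from 5.0 with a target of 4.0 and inertia 0.3, ten updates give 4 + 0.3**10 = 4.0000059049, which coincides with the mean printed by the `fit` call below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def inertia_update(previous, new, inertia):\n",
    "    # previous * inertia + new * (1 - inertia); inertia=1 freezes the parameter.\n",
    "    return previous * inertia + new * (1 - inertia)\n",
    "\n",
    "\n",
    "inertia_value = 5.0   # hypothetical starting mean\n",
    "inertia_target = 4.0  # hypothetical new estimate produced at every iteration\n",
    "inertia_history = [inertia_value]\n",
    "for _ in range(10):\n",
    "    inertia_value = inertia_update(inertia_value, inertia_target, 0.3)\n",
    "    inertia_history.append(inertia_value)\n",
    "\n",
    "print(inertia_history)"
   ]
  },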
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{\n",
       "    \"class\" : \"HiddenMarkovModel\",\n",
       "    \"name\" : \"None\",\n",
       "    \"start\" : {\n",
       "        \"class\" : \"State\",\n",
       "        \"distribution\" : null,\n",
       "        \"name\" : \"None-start\",\n",
       "        \"weight\" : 1.0\n",
       "    },\n",
       "    \"end\" : {\n",
       "        \"class\" : \"State\",\n",
       "        \"distribution\" : null,\n",
       "        \"name\" : \"None-end\",\n",
       "        \"weight\" : 1.0\n",
       "    },\n",
       "    \"states\" : [\n",
       "        {\n",
       "            \"class\" : \"State\",\n",
       "            \"distribution\" : {\n",
       "                \"class\" : \"Distribution\",\n",
       "                \"name\" : \"NormalDistribution\",\n",
       "                \"parameters\" : [\n",
       "                    4.0000059049,\n",
       "                    1.563474497595802\n",
       "                ],\n",
       "                \"frozen\" : false\n",
       "            },\n",
       "            \"name\" : \"Tied1\",\n",
       "            \"weight\" : 1.0\n",
       "        },\n",
       "        {\n",
       "            \"class\" : \"State\",\n",
       "            \"distribution\" : {\n",
       "                \"class\" : \"Distribution\",\n",
       "                \"name\" : \"NormalDistribution\",\n",
       "                \"parameters\" : [\n",
       "                    4.0000059049,\n",
       "                    1.563474497595802\n",
       "                ],\n",
       "                \"frozen\" : false\n",
       "            },\n",
       "            \"name\" : \"Tied2\",\n",
       "            \"weight\" : 1.0\n",
       "        },\n",
       "        {\n",
       "            \"class\" : \"State\",\n",
       "            \"distribution\" : null,\n",
       "            \"name\" : \"None-start\",\n",
       "            \"weight\" : 1.0\n",
       "        },\n",
       "        {\n",
       "            \"class\" : \"State\",\n",
       "            \"distribution\" : null,\n",
       "            \"name\" : \"None-end\",\n",
       "            \"weight\" : 1.0\n",
       "        }\n",
       "    ],\n",
       "    \"end_index\" : 3,\n",
       "    \"start_index\" : 2,\n",
       "    \"silent_index\" : 2,\n",
       "    \"edges\" : [\n",
       "        [\n",
       "            2,\n",
       "            0,\n",
       "            0.5,\n",
       "            0.5,\n",
       "            0\n",
       "        ],\n",
       "        [\n",
       "            2,\n",
       "            1,\n",
       "            0.5,\n",
       "            0.5,\n",
       "            1\n",
       "        ],\n",
       "        [\n",
       "            0,\n",
       "            1,\n",
       "            1.0,\n",
       "            0.5,\n",
       "            0\n",
       "        ],\n",
       "        [\n",
       "            1,\n",
       "            0,\n",
       "            1.0,\n",
       "            0.5,\n",
       "            1\n",
       "        ]\n",
       "    ],\n",
       "    \"distribution ties\" : [\n",
       "        [\n",
       "            0,\n",
       "            1\n",
       "        ],\n",
       "        [\n",
       "            1,\n",
       "            0\n",
       "        ]\n",
       "    ]\n",
       "}"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.fit( [[5, 2, 3, 4], [5, 7, 2, 3, 5]], distribution_inertia=0.3, edge_inertia=0.25 )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Pseudocounts"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Another way of regularizing your model is to add pseudocounts to your edges (the edges with non-zero probability). When the edges are updated, the pseudocount is added to the observed count of transitions across that edge before normalizing. This gives a more Bayesian estimate of the edge probability, and is useful if you have a large model and don't expect your training data to cross most of its edges. An example might be a complicated profile HMM, where you don't expect to see any inserts or deletes in your training data but don't want those edges to move away from their default values.\n",
    "\n",
    "In pomegranate, pseudocounts default to the initial probabilities, so that if you see no data, the edge values simply aren't updated. You can define edge-specific pseudocounts when you add each transition, or set a single pseudocount for all edges when you train. In either case, you must pass `use_pseudocount=True` to `fit`."
   ]
  },
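  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch of the pseudocount update for one state's outgoing edges, in plain Python. It assumes the transition counts have already been collected (for example, as expected counts from Baum-Welch) and that each edge's pseudocount is added to its count before normalizing. The state names `M1`, `M2`, `D2`, the function name, and the numbers are hypothetical; this is not pomegranate's implementation.\n",
    "\n",
    "```python\n",
    "from collections import defaultdict\n",
    "\n",
    "\n",
    "def estimate_transitions(counts, pseudocounts):\n",
    "    # counts and pseudocounts map (source, target) edges to non-negative numbers.\n",
    "    # Only edges listed in pseudocounts (the model's edges) are estimated.\n",
    "    totals = defaultdict(float)\n",
    "    for (src, dst), pc in pseudocounts.items():\n",
    "        totals[src] += counts.get((src, dst), 0.0) + pc\n",
    "    return {\n",
    "        (src, dst): (counts.get((src, dst), 0.0) + pc) / totals[src]\n",
    "        for (src, dst), pc in pseudocounts.items()\n",
    "    }\n",
    "\n",
    "\n",
    "# Pseudocounts set to the initial probabilities, mirroring the default described above.\n",
    "pseudocounts = {('M1', 'M2'): 0.9, ('M1', 'D2'): 0.1}\n",
    "\n",
    "# No training data: the initial probabilities come back unchanged.\n",
    "print(estimate_transitions({}, pseudocounts))\n",
    "\n",
    "# The data never uses the hypothetical match -> delete edge, yet it keeps\n",
    "# a small non-zero probability instead of collapsing to 0.\n",
    "print(estimate_transitions({('M1', 'M2'): 50.0}, pseudocounts))\n",
    "```\n",
    "\n",
    "Larger pseudocounts behave like more prior observations and pull the estimates harder toward the initial probabilities."
   ]
  },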
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{\n",
       "    \"class\" : \"HiddenMarkovModel\",\n",
       "    \"name\" : \"None\",\n",
       "    \"start\" : {\n",
       "        \"class\" : \"State\",\n",
       "        \"distribution\" : null,\n",
       "        \"name\" : \"None-start\",\n",
       "        \"weight\" : 1.0\n",
       "    },\n",
       "    \"end\" : {\n",
       "        \"class\" : \"State\",\n",
       "        \"distribution\" : null,\n",
       "        \"name\" : \"None-end\",\n",
       "        \"weight\" : 1.0\n",
       "    },\n",
       "    \"states\" : [\n",
       "        {\n",
       "            \"class\" : \"State\",\n",
       "            \"distribution\" : {\n",
       "                \"class\" : \"Distribution\",\n",
       "                \"name\" : \"NormalDistribution\",\n",
       "                \"parameters\" : [\n",
       "                    3.967631778032427,\n",
       "                    1.3388904282425669\n",
       "                ],\n",
       "                \"frozen\" : false\n",
       "            },\n",
       "            \"name\" : \"s1\",\n",
       "            \"weight\" : 1.0\n",
       "        },\n",
       "        {\n",
       "            \"class\" : \"State\",\n",
       "            \"distribution\" : {\n",
       "                \"class\" : \"Distribution\",\n",
       "                \"name\" : \"NormalDistribution\",\n",
       "                \"parameters\" : [\n",
       "                    4.038616148637643,\n",
       "                    1.7942514140960613\n",
       "                ],\n",
       "                \"frozen\" : false\n",
       "            },\n",
       "            \"name\" : \"s2\",\n",
       "            \"weight\" : 1.0\n",
       "        },\n",
       "        {\n",
       "            \"class\" : \"State\",\n",
       "            \"distribution\" : null,\n",
       "            \"name\" : \"None-start\",\n",
       "            \"weight\" : 1.0\n",
       "        },\n",
       "        {\n",
       "            \"class\" : \"State\",\n",
       "            \"distribution\" : null,\n",
       "            \"name\" : \"None-end\",\n",
       "            \"weight\" : 1.0\n",
       "        }\n",
       "    ],\n",
       "    \"end_index\" : 3,\n",
       "    \"start_index\" : 2,\n",
       "    \"silent_index\" : 2,\n",
       "    \"edges\" : [\n",
       "        [\n",
       "            2,\n",
       "            0,\n",
       "            0.7883901790047138,\n",
       "            4.2,\n",
       "            null\n",
       "        ],\n",
       "        [\n",
       "            2,\n",
       "            1,\n",
       "            0.21160982099528625,\n",
       "            1.3,\n",
       "            null\n",
       "        ],\n",
       "        [\n",
       "            0,\n",
       "            1,\n",
       "            1.0,\n",
       "            5.2,\n",
       "            null\n",
       "        ],\n",
       "        [\n",
       "            1,\n",
       "            0,\n",
       "            1.0,\n",
       "            0.9,\n",
       "            null\n",
       "        ]\n",
       "    ],\n",
       "    \"distribution ties\" : []\n",
       "}"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "s1 = State( NormalDistribution( 3, 1 ), name=\"s1\" )\n",
    "s2 = State( NormalDistribution( 6, 2 ), name=\"s2\" )\n",
    "\n",
    "model = HiddenMarkovModel()\n",
    "model.add_states( [s1, s2] )\n",
    "model.add_transition( model.start, s1, 0.5, pseudocount=4.2 )\n",
    "model.add_transition( model.start, s2, 0.5, pseudocount=1.3 )\n",
    "model.add_transition( s1, s2, 0.5, pseudocount=5.2 )\n",
    "model.add_transition( s2, s1, 0.5, pseudocount=0.9 )\n",
    "model.bake()\n",
    "model.fit( [[5, 2, 3, 4], [5, 7, 2, 3, 5]], max_iterations=5, use_pseudocount=True )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The other way is to put a blanket pseudocount on all edges, using the `transition_pseudocount` keyword in `fit`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{\n",
       "    \"class\" : \"HiddenMarkovModel\",\n",
       "    \"name\" : \"None\",\n",
       "    \"start\" : {\n",
       "        \"class\" : \"State\",\n",
       "        \"distribution\" : null,\n",
       "        \"name\" : \"None-start\",\n",
       "        \"weight\" : 1.0\n",
       "    },\n",
       "    \"end\" : {\n",
       "        \"class\" : \"State\",\n",
       "        \"distribution\" : null,\n",
       "        \"name\" : \"None-end\",\n",
       "        \"weight\" : 1.0\n",
       "    },\n",
       "    \"states\" : [\n",
       "        {\n",
       "            \"class\" : \"State\",\n",
       "            \"distribution\" : {\n",
       "                \"class\" : \"Distribution\",\n",
       "                \"name\" : \"NormalDistribution\",\n",
       "                \"parameters\" : [\n",
       "                    3.982144723850736,\n",
       "                    1.547455896814475\n",
       "                ],\n",
       "                \"frozen\" : false\n",
       "            },\n",
       "            \"name\" : \"s1\",\n",
       "            \"weight\" : 1.0\n",
       "        },\n",
       "        {\n",
       "            \"class\" : \"State\",\n",
       "            \"distribution\" : {\n",
       "                \"class\" : \"Distribution\",\n",
       "                \"name\" : \"NormalDistribution\",\n",
       "                \"parameters\" : [\n",
       "                    4.018144756842671,\n",
       "                    1.5793744688756404\n",
       "                ],\n",
       "                \"frozen\" : false\n",
       "            },\n",
       "            \"name\" : \"s2\",\n",
       "            \"weight\" : 1.0\n",
       "        },\n",
       "        {\n",
       "            \"class\" : \"State\",\n",
       "            \"distribution\" : null,\n",
       "            \"name\" : \"None-start\",\n",
       "            \"weight\" : 1.0\n",
       "        },\n",
       "        {\n",
       "            \"class\" : \"State\",\n",
       "            \"distribution\" : null,\n",
       "            \"name\" : \"None-end\",\n",
       "            \"weight\" : 1.0\n",
       "        }\n",
       "    ],\n",
       "    \"end_index\" : 3,\n",
       "    \"start_index\" : 2,\n",
       "    \"silent_index\" : 2,\n",
       "    \"edges\" : [\n",
       "        [\n",
       "            2,\n",
       "            0,\n",
       "            0.5007412253170119,\n",
       "            0.5,\n",
       "            null\n",
       "        ],\n",
       "        [\n",
       "            2,\n",
       "            1,\n",
       "            0.49925877468298807,\n",
       "            0.5,\n",
       "            null\n",
       "        ],\n",
       "        [\n",
       "            0,\n",
       "            1,\n",
       "            1.0,\n",
       "            0.5,\n",
       "            null\n",
       "        ],\n",
       "        [\n",
       "            1,\n",
       "            0,\n",
       "            1.0,\n",
       "            0.5,\n",
       "            null\n",
       "        ]\n",
       "    ],\n",
       "    \"distribution ties\" : []\n",
       "}"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "s1 = State( NormalDistribution( 3, 1 ), name=\"s1\" )\n",
    "s2 = State( NormalDistribution( 6, 2 ), name=\"s2\" )\n",
    "\n",
    "model = HiddenMarkovModel()\n",
    "model.add_states( [s1, s2] )\n",
    "model.add_transition( model.start, s1, 0.5 )\n",
    "model.add_transition( model.start, s2, 0.5 )\n",
    "model.add_transition( s1, s2, 0.5 )\n",
    "model.add_transition( s2, s1, 0.5 )\n",
    "model.bake()\n",
    "model.fit( [[5, 2, 3, 4], [5, 7, 2, 3, 5]], max_iterations=5, transition_pseudocount=20, use_pseudocount=True )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
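    "We can see that there isn't as much of an improvement: the start probabilities barely move from 0.5 (0.5007 and 0.4993, compared with roughly 0.79 and 0.21 above), because a pseudocount of 20 on every edge outweighs the few transitions observed in two short sequences. This is part of regularization, though. We sacrifice fitting the training data exactly so that the model generalizes better to future data. Most of the training improvement here is likely coming from the emission distributions fitting the data better."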
   ]
  },
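  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The shrinkage effect of a blanket pseudocount can be sketched with a hypothetical state that has two outgoing edges, assuming the pseudocount is simply added to every edge's count before normalizing. The counts below are made up for illustration and are not the expected counts pomegranate computes for the sequences above; the function name is hypothetical.\n",
    "\n",
    "```python\n",
    "def smoothed_probabilities(counts, pseudocount):\n",
    "    # Add the same pseudocount to every outgoing edge, then normalize.\n",
    "    total = sum(counts) + pseudocount * len(counts)\n",
    "    return [(c + pseudocount) / total for c in counts]\n",
    "\n",
    "\n",
    "# Hypothetical counts for the two start edges after two training sequences.\n",
    "start_counts = [2.0, 0.0]\n",
    "\n",
    "for pseudocount in (0.0, 1.0, 20.0):\n",
    "    print(pseudocount, smoothed_probabilities(start_counts, pseudocount))\n",
    "```\n",
    "\n",
    "Without a pseudocount the unused edge drops to 0; with a pseudocount of 20 both edges stay close to 0.5 (22/42 and 20/42), the same qualitative behavior as the start probabilities learned above."
   ]
  },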
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Multithreaded Training"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since pomegranate is implemented in Cython, most functions are written with the GIL released. One benefit is that we can use multithreading to make some computationally intensive tasks take less time. A downside is that Python doesn't always play nicely with multithreading, so in some cases multithreaded training can take significantly longer than single-threaded training. I investigated this in an early multithreading pull request <a href=\"https://github.com/jmschrei/pomegranate/pull/30\">here</a>. Things have improved since then, but the gist is that multithreading may be detrimental for a small model (fewer than 15 states), while the larger your model is, the closer the speedup gets to the number of threads you use. You can specify multithreading using the `n_jobs` keyword. All structures in pomegranate are thread safe, so you don't need to worry about race conditions. The two cells below fit the same model without and with `n_jobs=4`, and the learned parameters are identical."
   ]
  },
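  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A common way to parallelize training, and one that makes the result independent of the number of threads, is to compute statistics for each sequence separately and then sum them. The sketch below shows that map-reduce pattern with Python's `concurrent.futures.ThreadPoolExecutor`, fitting a single normal distribution as a simplified stand-in for the per-state statistics an HMM collects. It is an illustration of the general pattern under that assumption, not how pomegranate schedules its Cython threads; the function names are hypothetical, and because this pure-Python code holds the GIL it will not actually run faster.\n",
    "\n",
    "```python\n",
    "from concurrent.futures import ThreadPoolExecutor\n",
    "\n",
    "\n",
    "def summarize(sequence):\n",
    "    # Sufficient statistics of a normal distribution: (n, sum of x, sum of x^2).\n",
    "    return len(sequence), sum(sequence), sum(x * x for x in sequence)\n",
    "\n",
    "\n",
    "def fit_normal(sequences, n_jobs=1):\n",
    "    # Map: summarize each sequence, possibly on several threads.\n",
    "    with ThreadPoolExecutor(max_workers=n_jobs) as pool:\n",
    "        summaries = list(pool.map(summarize, sequences))\n",
    "    # Reduce: add the statistics, then convert them to parameters.\n",
    "    n = sum(s[0] for s in summaries)\n",
    "    total = sum(s[1] for s in summaries)\n",
    "    total_sq = sum(s[2] for s in summaries)\n",
    "    mean = total / n\n",
    "    std = (total_sq / n - mean ** 2) ** 0.5\n",
    "    return mean, std\n",
    "\n",
    "\n",
    "sequences = [[5, 2, 3, 4, 7, 3, 6, 3, 5, 2, 4], [5, 7, 2, 3, 5, 1, 3, 5, 6, 2]]\n",
    "print(fit_normal(sequences, n_jobs=1))\n",
    "print(fit_normal(sequences, n_jobs=4))  # same result; only the scheduling differs\n",
    "```"
   ]
  },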
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{\n",
       "    \"class\" : \"HiddenMarkovModel\",\n",
       "    \"name\" : \"None\",\n",
       "    \"start\" : {\n",
       "        \"class\" : \"State\",\n",
       "        \"distribution\" : null,\n",
       "        \"name\" : \"None-start\",\n",
       "        \"weight\" : 1.0\n",
       "    },\n",
       "    \"end\" : {\n",
       "        \"class\" : \"State\",\n",
       "        \"distribution\" : null,\n",
       "        \"name\" : \"None-end\",\n",
       "        \"weight\" : 1.0\n",
       "    },\n",
       "    \"states\" : [\n",
       "        {\n",
       "            \"class\" : \"State\",\n",
       "            \"distribution\" : {\n",
       "                \"class\" : \"Distribution\",\n",
       "                \"name\" : \"NormalDistribution\",\n",
       "                \"parameters\" : [\n",
       "                    3.200000252045981,\n",
       "                    1.6613245827659173\n",
       "                ],\n",
       "                \"frozen\" : false\n",
       "            },\n",
       "            \"name\" : \"s1\",\n",
       "            \"weight\" : 1.0\n",
       "        },\n",
       "        {\n",
       "            \"class\" : \"State\",\n",
       "            \"distribution\" : {\n",
       "                \"class\" : \"Distribution\",\n",
       "                \"name\" : \"NormalDistribution\",\n",
       "                \"parameters\" : [\n",
       "                    4.6363634085403405,\n",
       "                    1.431638224527499\n",
       "                ],\n",
       "                \"frozen\" : false\n",
       "            },\n",
       "            \"name\" : \"s2\",\n",
       "            \"weight\" : 1.0\n",
       "        },\n",
       "        {\n",
       "            \"class\" : \"State\",\n",
       "            \"distribution\" : null,\n",
       "            \"name\" : \"None-start\",\n",
       "            \"weight\" : 1.0\n",
       "        },\n",
       "        {\n",
       "            \"class\" : \"State\",\n",
       "            \"distribution\" : null,\n",
       "            \"name\" : \"None-end\",\n",
       "            \"weight\" : 1.0\n",
       "        }\n",
       "    ],\n",
       "    \"end_index\" : 3,\n",
       "    \"start_index\" : 2,\n",
       "    \"silent_index\" : 2,\n",
       "    \"edges\" : [\n",
       "        [\n",
       "            2,\n",
       "            0,\n",
       "            4.036978946423165e-07,\n",
       "            0.5,\n",
       "            null\n",
       "        ],\n",
       "        [\n",
       "            2,\n",
       "            1,\n",
       "            0.9999995963021053,\n",
       "            0.5,\n",
       "            null\n",
       "        ],\n",
       "        [\n",
       "            0,\n",
       "            1,\n",
       "            1.0,\n",
       "            0.5,\n",
       "            null\n",
       "        ],\n",
       "        [\n",
       "            1,\n",
       "            0,\n",
       "            1.0,\n",
       "            0.5,\n",
       "            null\n",
       "        ]\n",
       "    ],\n",
       "    \"distribution ties\" : []\n",
       "}"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "s1 = State( NormalDistribution( 3, 1 ), name=\"s1\" )\n",
    "s2 = State( NormalDistribution( 6, 2 ), name=\"s2\" )\n",
    "\n",
    "model = HiddenMarkovModel()\n",
    "model.add_states( [s1, s2] )\n",
    "model.add_transition( model.start, s1, 0.5 )\n",
    "model.add_transition( model.start, s2, 0.5 )\n",
    "model.add_transition( s1, s2, 0.5 )\n",
    "model.add_transition( s2, s1, 0.5 )\n",
    "model.bake()\n",
    "model.fit( [[5, 2, 3, 4, 7, 3, 6, 3, 5, 2, 4], [5, 7, 2, 3, 5, 1, 3, 5, 6, 2]], max_iterations=5 )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{\n",
       "    \"class\" : \"HiddenMarkovModel\",\n",
       "    \"name\" : \"None\",\n",
       "    \"start\" : {\n",
       "        \"class\" : \"State\",\n",
       "        \"distribution\" : null,\n",
       "        \"name\" : \"None-start\",\n",
       "        \"weight\" : 1.0\n",
       "    },\n",
       "    \"end\" : {\n",
       "        \"class\" : \"State\",\n",
       "        \"distribution\" : null,\n",
       "        \"name\" : \"None-end\",\n",
       "        \"weight\" : 1.0\n",
       "    },\n",
       "    \"states\" : [\n",
       "        {\n",
       "            \"class\" : \"State\",\n",
       "            \"distribution\" : {\n",
       "                \"class\" : \"Distribution\",\n",
       "                \"name\" : \"NormalDistribution\",\n",
       "                \"parameters\" : [\n",
       "                    3.200000252045981,\n",
       "                    1.6613245827659173\n",
       "                ],\n",
       "                \"frozen\" : false\n",
       "            },\n",
       "            \"name\" : \"s1\",\n",
       "            \"weight\" : 1.0\n",
       "        },\n",
       "        {\n",
       "            \"class\" : \"State\",\n",
       "            \"distribution\" : {\n",
       "                \"class\" : \"Distribution\",\n",
       "                \"name\" : \"NormalDistribution\",\n",
       "                \"parameters\" : [\n",
       "                    4.6363634085403405,\n",
       "                    1.431638224527499\n",
       "                ],\n",
       "                \"frozen\" : false\n",
       "            },\n",
       "            \"name\" : \"s2\",\n",
       "            \"weight\" : 1.0\n",
       "        },\n",
       "        {\n",
       "            \"class\" : \"State\",\n",
       "            \"distribution\" : null,\n",
       "            \"name\" : \"None-start\",\n",
       "            \"weight\" : 1.0\n",
       "        },\n",
       "        {\n",
       "            \"class\" : \"State\",\n",
       "            \"distribution\" : null,\n",
       "            \"name\" : \"None-end\",\n",
       "            \"weight\" : 1.0\n",
       "        }\n",
       "    ],\n",
       "    \"end_index\" : 3,\n",
       "    \"start_index\" : 2,\n",
       "    \"silent_index\" : 2,\n",
       "    \"edges\" : [\n",
       "        [\n",
       "            2,\n",
       "            0,\n",
       "            4.036978946423165e-07,\n",
       "            0.5,\n",
       "            null\n",
       "        ],\n",
       "        [\n",
       "            2,\n",
       "            1,\n",
       "            0.9999995963021053,\n",
       "            0.5,\n",
       "            null\n",
       "        ],\n",
       "        [\n",
       "            0,\n",
       "            1,\n",
       "            1.0,\n",
       "            0.5,\n",
       "            null\n",
       "        ],\n",
       "        [\n",
       "            1,\n",
       "            0,\n",
       "            1.0,\n",
       "            0.5,\n",
       "            null\n",
       "        ]\n",
       "    ],\n",
       "    \"distribution ties\" : []\n",
       "}"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "s1 = State( NormalDistribution( 3, 1 ), name=\"s1\" )\n",
    "s2 = State( NormalDistribution( 6, 2 ), name=\"s2\" )\n",
    "\n",
    "model = HiddenMarkovModel()\n",
    "model.add_states( [s1, s2] )\n",
    "model.add_transition( model.start, s1, 0.5 )\n",
    "model.add_transition( model.start, s2, 0.5 )\n",
    "model.add_transition( s1, s2, 0.5 )\n",
    "model.add_transition( s2, s1, 0.5 )\n",
    "model.bake()\n",
    "model.fit( [[5, 2, 3, 4, 7, 3, 6, 3, 5, 2, 4], [5, 7, 2, 3, 5, 1, 3, 5, 6, 2]], max_iterations=5, n_jobs=4 )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "## Serialization"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Hidden Markov models support serialization to JSON using `to_json()` and `from_json( json )`. This is useful if you want to train an HMM on large amounts of data, which takes a significant amount of time, and then use the model in the future without repeating this computationally intensive step (this should sound familiar by now). Let's look at the original CG island model, since it's significantly smaller."
   ]
  },
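  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see why a JSON round trip preserves the model's behavior, the sketch below stores the same CG island parameters in a plain dictionary, serializes it with Python's standard `json` module, and scores the sequence with a log-space forward algorithm before and after. It does not use pomegranate's JSON format; the dictionary layout and function names are hypothetical. It also assumes that, since the model has no edges into the end state, the log probability sums over every possible final state.\n",
    "\n",
    "```python\n",
    "import json\n",
    "import math\n",
    "\n",
    "\n",
    "def logsumexp(values):\n",
    "    # Numerically stable log(sum(exp(v))).\n",
    "    m = max(values)\n",
    "    return m + math.log(sum(math.exp(v - m) for v in values))\n",
    "\n",
    "\n",
    "def forward_log_probability(model, sequence):\n",
    "    # Log-space forward algorithm, summed over all final states (no end state).\n",
    "    # Assumes every probability is non-zero, since math.log(0) would raise.\n",
    "    states = list(model['emissions'])\n",
    "    log_alpha = {\n",
    "        s: math.log(model['start'][s]) + math.log(model['emissions'][s][sequence[0]])\n",
    "        for s in states\n",
    "    }\n",
    "    for symbol in sequence[1:]:\n",
    "        log_alpha = {\n",
    "            t: logsumexp([log_alpha[s] + math.log(model['transitions'][s][t]) for s in states])\n",
    "            + math.log(model['emissions'][t][symbol])\n",
    "            for t in states\n",
    "        }\n",
    "    return logsumexp(list(log_alpha.values()))\n",
    "\n",
    "\n",
    "model = {\n",
    "    'start': {'background': 0.5, 'CG island': 0.5},\n",
    "    'transitions': {\n",
    "        'background': {'background': 0.5, 'CG island': 0.5},\n",
    "        'CG island': {'background': 0.5, 'CG island': 0.5},\n",
    "    },\n",
    "    'emissions': {\n",
    "        'background': {'A': 0.25, 'C': 0.25, 'G': 0.25, 'T': 0.25},\n",
    "        'CG island': {'A': 0.10, 'C': 0.40, 'G': 0.40, 'T': 0.10},\n",
    "    },\n",
    "}\n",
    "\n",
    "seq = list('CGACTACTGACTACTCGCCGACGCGACTGCCGTCTATACTGCGCATACGGC')\n",
    "restored = json.loads(json.dumps(model))  # serialize, then deserialize\n",
    "\n",
    "print(forward_log_probability(model, seq))\n",
    "print(forward_log_probability(restored, seq))\n",
    "```\n",
    "\n",
    "Both calls print the same number, since Python's `json` module writes floats with enough digits to read them back exactly. Because every transition is 0.5, the forward sum factorizes: the sequence has 31 C/G and 20 A/T positions, so the log probability is 31 ln 0.325 + 20 ln 0.175, about -69.7012, which agrees with the value `hmm.log_probability` reports below up to floating-point differences."
   ]
  },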
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "seq = list('CGACTACTGACTACTCGCCGACGCGACTGCCGTCTATACTGCGCATACGGC')\n",
    "\n",
    "d1 = DiscreteDistribution({'A': 0.25, 'C': 0.25, 'G': 0.25, 'T': 0.25})\n",
    "d2 = DiscreteDistribution({'A': 0.10, 'C': 0.40, 'G': 0.40, 'T': 0.10})\n",
    "\n",
    "s1 = State( d1, name='background' )\n",
    "s2 = State( d2, name='CG island' )\n",
    "\n",
    "hmm = HiddenMarkovModel()\n",
    "hmm.add_states(s1, s2)\n",
    "hmm.add_transition( hmm.start, s1, 0.5 )\n",
    "hmm.add_transition( hmm.start, s2, 0.5 )\n",
    "hmm.add_transition( s1, s1, 0.5 )\n",
    "hmm.add_transition( s1, s2, 0.5 )\n",
    "hmm.add_transition( s2, s1, 0.5 )\n",
    "hmm.add_transition( s2, s2, 0.5 )\n",
    "hmm.bake()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\n",
      "    \"class\" : \"HiddenMarkovModel\",\n",
      "    \"name\" : \"None\",\n",
      "    \"start\" : {\n",
      "        \"class\" : \"State\",\n",
      "        \"distribution\" : null,\n",
      "        \"name\" : \"None-start\",\n",
      "        \"weight\" : 1.0\n",
      "    },\n",
      "    \"end\" : {\n",
      "        \"class\" : \"State\",\n",
      "        \"distribution\" : null,\n",
      "        \"name\" : \"None-end\",\n",
      "        \"weight\" : 1.0\n",
      "    },\n",
      "    \"states\" : [\n",
      "        {\n",
      "            \"class\" : \"State\",\n",
      "            \"distribution\" : {\n",
      "                \"class\" : \"Distribution\",\n",
      "                \"dtype\" : \"str\",\n",
      "                \"name\" : \"DiscreteDistribution\",\n",
      "                \"parameters\" : [\n",
      "                    {\n",
      "                        \"A\" : 0.1,\n",
      "                        \"C\" : 0.4,\n",
      "                        \"G\" : 0.4,\n",
      "                        \"T\" : 0.1\n",
      "                    }\n",
      "                ],\n",
      "                \"frozen\" : false\n",
      "            },\n",
      "            \"name\" : \"CG island\",\n",
      "            \"weight\" : 1.0\n",
      "        },\n",
      "        {\n",
      "            \"class\" : \"State\",\n",
      "            \"distribution\" : {\n",
      "                \"class\" : \"Distribution\",\n",
      "                \"dtype\" : \"str\",\n",
      "                \"name\" : \"DiscreteDistribution\",\n",
      "                \"parameters\" : [\n",
      "                    {\n",
      "                        \"A\" : 0.25,\n",
      "                        \"C\" : 0.25,\n",
      "                        \"G\" : 0.25,\n",
      "                        \"T\" : 0.25\n",
      "                    }\n",
      "                ],\n",
      "                \"frozen\" : false\n",
      "            },\n",
      "            \"name\" : \"background\",\n",
      "            \"weight\" : 1.0\n",
      "        },\n",
      "        {\n",
      "            \"class\" : \"State\",\n",
      "            \"distribution\" : null,\n",
      "            \"name\" : \"None-start\",\n",
      "            \"weight\" : 1.0\n",
      "        },\n",
      "        {\n",
      "            \"class\" : \"State\",\n",
      "            \"distribution\" : null,\n",
      "            \"name\" : \"None-end\",\n",
      "            \"weight\" : 1.0\n",
      "        }\n",
      "    ],\n",
      "    \"end_index\" : 3,\n",
      "    \"start_index\" : 2,\n",
      "    \"silent_index\" : 2,\n",
      "    \"edges\" : [\n",
      "        [\n",
      "            2,\n",
      "            1,\n",
      "            0.5,\n",
      "            0.5,\n",
      "            null\n",
      "        ],\n",
      "        [\n",
      "            2,\n",
      "            0,\n",
      "            0.5,\n",
      "            0.5,\n",
      "            null\n",
      "        ],\n",
      "        [\n",
      "            1,\n",
      "            1,\n",
      "            0.5,\n",
      "            0.5,\n",
      "            null\n",
      "        ],\n",
      "        [\n",
      "            1,\n",
      "            0,\n",
      "            0.5,\n",
      "            0.5,\n",
      "            null\n",
      "        ],\n",
      "        [\n",
      "            0,\n",
      "            1,\n",
      "            0.5,\n",
      "            0.5,\n",
      "            null\n",
      "        ],\n",
      "        [\n",
      "            0,\n",
      "            0,\n",
      "            0.5,\n",
      "            0.5,\n",
      "            null\n",
      "        ]\n",
      "    ],\n",
      "    \"distribution ties\" : []\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "print(hmm.to_json())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-69.70121909739684\n"
     ]
    }
   ],
   "source": [
    "seq = list('CGACTACTGACTACTCGCCGACGCGACTGCCGTCTATACTGCGCATACGGC')\n",
    "print(hmm.log_probability( seq ))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-69.70121909739684\n"
     ]
    }
   ],
   "source": [
    "hmm_2 = HiddenMarkovModel.from_json( hmm.to_json() )\n",
    "print(hmm_2.log_probability( seq ))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
